import numpy,os,time
from matplotlib import pyplot
from scipy import constants
from jqc import jqc_plot
from diatom import Hamiltonian
from sympy.physics.wigner import wigner_3j,wigner_9j
from matplotlib.collections import LineCollection
from matplotlib.colors import LogNorm,LinearSegmentedColormap
from matplotlib.ticker import (
    AutoLocator, AutoMinorLocator)
from matplotlib import gridspec

# Constants object taken from the diatom package (presumably the molecular
# constants of RbCs -- confirm against diatom). It is not referenced again in
# the code visible in this script.
Consts = Hamiltonian.RbCs

# Planck constant from scipy.constants; used below to convert energies into
# frequencies.
h = constants.h

# Apply the jqc_plot style before any figure is created (presumably this sets
# the Matplotlib defaults -- see jqc_plot).
jqc_plot.plot_style()
# Directory containing this script, and its parent directory.
cwd = os.path.dirname(os.path.abspath(__file__))
root = os.path.dirname(cwd)

# Colour lookup provided by jqc_plot; indexed by name below ('sand', 'blue').
colours = jqc_plot.colours

# 3 rows x 5 columns layout. Columns 0 and 2 hold the panels and column 4 the
# colour bar; no subplot is placed in the narrow columns 1 and 3.
grid=gridspec.GridSpec(3,5,width_ratios = [1,0.1,1.,0.05,0.1])

# Colour channels of the custom map, in the segment-data format taken by
# LinearSegmentedColormap: every row is (position, value below, value above),
# with positions running from 0 to 1 and channel values scaled to 0-1.
colour_dict_twk_blue = {
    "red" : [(0.0,244.0/255.0,244.0/255.0),
            (0.33,124.0/255.0,124.0/255.0),
            (0.66,0.0,0.0),
            (1.0,0.0,0.0)] ,
    "green" : [(0.0,234.0/255.0,234.0/255.0),
            (0.33,154.0/255.0,154.0/255.0),
            (0.66,70.0/255.0,70.0/255.0),
            (1.0,32.0/255.0,32.0/255.0)]  ,
    "blue" : [(0.0,168.0/255.0,168.0/255.0),
            (0.33,148.0/255.0,148/255.0),
            (0.66,127.0/255.0,127.0/255.0),
            (1.0,58.0/255.0,58.0/255.0)]
}
# Shallow copy: the colour channel lists are shared with the dictionary
# above, only the extra 'alpha' key is private to this one.
colour_dict_twk_blue_alpha = colour_dict_twk_blue.copy()
# Opacity rises from 0 at the bottom of the map to 1 at its midpoint and
# stays at 1 above it.
colour_dict_twk_blue_alpha['alpha'] = ((0.0, 0.0,0.0),
                #   (0.25,1.0, 1.0),
                   (0.5, 1.0, 1.0),
                #   (0.75,1.0, 1.0),
                   (1.0, 1.0, 1.0))


RbCs_map_twk_blue = LinearSegmentedColormap("RbCs_map_tweak_blue",
                                                colour_dict_twk_blue_alpha)

# Make the map available by name ("RbCs_map_tweak_blue"). pyplot.register_cmap
# was deprecated in Matplotlib 3.7 and later removed, so the colormap registry
# is used where it exists; the old call remains as a fallback for Matplotlib
# versions that predate the registry.
try:
    from matplotlib import colormaps as _colormaps
except ImportError:
    pyplot.register_cmap(cmap=RbCs_map_twk_blue)
else:
    # force=True keeps re-running the script in one interpreter session
    # working, as re-registration did with the old call.
    _colormaps.register(RbCs_map_twk_blue, force=True)

def make_segments(x, y):
    '''
    Turn x and y coordinates into the segment array expected by
    LineCollection.

    Consecutive points are paired up, so n points give n-1 segments and the
    result has the form
    numlines x (points per line, here 2) x 2 (x and y).
    '''
    # One (x, y) vertex per entry, with a length-1 middle axis so that the
    # start and end vertices can be joined along that axis.
    vertices = numpy.array([x, y]).T.reshape(-1, 1, 2)
    starts, ends = vertices[:-1], vertices[1:]

    return numpy.concatenate([starts, ends], axis=1)

def colorline(x, y, z=None, cmap=pyplot.get_cmap('copper'),
                norm=pyplot.Normalize(0.0, 1.0), linewidth=3, alpha=1.0,
                legend=False,ax=None):
    '''
    Plot a colored line with coordinates x and y
    Optionally specify colors in the array z
    Optionally specify a colormap, a norm function and a line width

    x, y      : coordinates of the points along the line.
    z         : values mapped to colours through cmap and norm. If None, the
                values are equally spaced on [0,1], one per point. A single
                number is wrapped in a one-element array.
    cmap      : colormap handed to the LineCollection.
    norm      : normalisation handed to the LineCollection.
    linewidth : width of the line.
    alpha     : accepted but not used by this function.
    legend    : accepted but not used by this function.
    ax        : axes to draw on; the current axes are used if None.

    The default cmap and norm objects are created once, when the function is
    defined, and are shared between calls that rely on them.

    Returns the LineCollection, which has been added to ax with zorder 1.25.
    '''
    # Identity test: "== None" would go through whatever __eq__ the object
    # passed as ax happens to define.
    if ax is None:
        ax = pyplot.gca()

    # Default colors equally spaced on [0,1]:
    if z is None:
        z = numpy.linspace(0.0, 1.0, len(x))

    # Special case if a single number:
    if not hasattr(z, "__iter__"):#to check for numerical input -- this is a hack
        z = numpy.array([z])

    z = numpy.asarray(z)

    segments = make_segments(x, y)
    lc = LineCollection(segments, array=z, cmap=cmap, norm=norm,
                        linewidth=linewidth,zorder=1.25)

    ax.add_collection(lc)

    return lc

def dipolez(Nmax,d):
    '''
    Generates the matrix used for the induced dipole moment of a rigid rotor.

    Basis states are labelled (N, M). They are ordered by increasing N from 0
    to Nmax and, within each N, by decreasing M from +N to -N. With (N1, M1)
    labelling row i and (N2, M2) labelling column j, the element [i, j] is

        d * sqrt((2*N1+1)*(2*N2+1)) * (-1)**M1
          * wigner_3j(N1,1,N2,-M1,0,M2) * wigner_3j(N1,1,N2,0,0,0)

    Nmax : largest N included in the basis.
    d    : factor multiplying every element.

    Returns a complex numpy array with (Nmax+1)**2 rows and columns.
    '''
    # Every (N, M) label, in the order of the matrix rows and columns.
    basis = [(N, M) for N in range(0,Nmax+1) for M in range(N,-(N+1),-1)]
    shape = len(basis)
    # numpy.complex was only an alias of the builtin complex and has been
    # removed from NumPy; complex128 is the dtype it resolved to.
    Dmat = numpy.zeros((shape,shape),dtype=numpy.complex128)
    for i,(N1,M1) in enumerate(basis):
        for j,(N2,M2) in enumerate(basis):
            Dmat[i,j]=d*numpy.sqrt((2*N1+1)*(2*N2+1))*(-1)**(M1)*\
            wigner_3j(N1,1,N2,-M1,0,M2)*wigner_3j(N1,1,N2,0,0,0)
    return Dmat

# Largest N passed to dipolez when the dipole matrix is built.
Nmax = 5
# Used below as the size (2*I+1) of the identity matrices that the dipole
# matrix is Kronecker-multiplied with (presumably the two nuclear spins --
# confirm).
I1 = 3/2
I2 = 7/2


fig = pyplot.figure("STARK")

# Top row. Left-hand panels sit in grid column 0, right-hand panels in grid
# column 2.
ax1 = fig.add_subplot(grid[0,0])
ax2 = fig.add_subplot(grid[0,2],sharey=ax1)

# Only the bottom row shows x tick labels and only the left column shows
# y tick labels.
ax1.tick_params(labelbottom=False)
ax2.tick_params(labelbottom=False,labelleft=False)

# Frequency offset of each column, written above the top panels; the same
# numbers are stored in `offset` below and subtracted from the plotted
# frequencies.
ax1.text(0,1.05,"+980 MHz",fontsize=15,clip_on=False,transform=ax1.transAxes)
ax2.text(0,1.05,"+998.5 MHz",fontsize=15,clip_on=False,transform=ax2.transAxes)

# Column titles.
ax1.text(0.5,1.25,"$E_z$=0 V$\\,$cm$^{-1}$",fontsize=15,clip_on=False,
            transform=ax1.transAxes,horizontalalignment='center')
ax2.text(0.5,1.25,"$E_z$=300 V$\\,$cm$^{-1}$",fontsize=15,clip_on=False,
            transform=ax2.transAxes,horizontalalignment='center')

# Middle row, sharing its axes with the panels above.
ax3 = fig.add_subplot(grid[1,0],sharex = ax1,sharey=ax1)
ax4 = fig.add_subplot(grid[1,2],sharex=ax2,sharey=ax3)

ax3.set_ylabel("Transition Frequency (MHz)")

ax3.tick_params(labelbottom=False)
ax4.tick_params(labelbottom=False,labelleft=False)

# Bottom row; these two panels carry the x axis labels.
ax5 = fig.add_subplot(grid[2,0],sharex=ax1,sharey=ax1)
ax6 = fig.add_subplot(grid[2,2],sharex=ax2,sharey=ax5)

ax6.tick_params(labelleft=False)

ax5.set_xlabel("Intensity, $I$ (kW$\\,$cm$^{-2}$)")
ax6.set_xlabel("Intensity, $I$ (kW$\\,$cm$^{-2}$)")

# The x axes are shared within each column, so these limits apply to every
# panel of the column.
ax5.set_xlim(0,10)
ax6.set_xlim(0,10)

# Panel letters on the top row.
ax1.text(0.81,0.78,"(a)",transform=ax1.transAxes,fontsize=20)
ax2.text(0.81,0.78,"(b)",transform=ax2.transAxes,fontsize=20)

# Per-panel settings; the lists below are parallel and indexed by panel, in
# the order of `axes`.
axes = [ax1,ax2,ax3,ax4,ax5,ax6]
# Name fragment of the data files belonging to each panel.
fnames = ["B0_0","B0_300","B54_0","B54_300","B90_0","B90_300"]
# Not referenced by the code visible in this script.
beta = ["0","0","54.7","54.7","90","90"]

# Text written in the lower left corner of each panel.
label = ["$\\beta = 0^\\circ$","$\\beta = 0^\\circ$",
        "$\\beta_\\mathrm{magic}$","$\\beta_\\mathrm{magic}$",
        "$\\beta = 90^\\circ$","$\\beta = 90^\\circ$"]

# Value subtracted from the transition frequencies of each panel before
# plotting.
offset = [980,998.5,980,998.5,980,998.5]


# Data directory next to this script. Every path is built with os.path.join
# so the script does not depend on "\\" being a path separator; the empty
# last component keeps the trailing separator on fpath.
fpath = os.path.join(cwd,"Data","")

# Dipole matrix extended by the two identity blocks. It is the same for every
# panel, so it is built at most once, and only if transition dipole moments
# (TDM) have to be recomputed because their file could not be read.
dz = None

for i,(ax,fname,panel_label,panel_offset) in enumerate(
        zip(axes,fnames,label,offset)):
    then = time.time()
    ax.text(0.05,0.1,panel_label,transform=ax.transAxes,fontsize=15)

    # Row 0 of the energy file is the intensity axis, every further row is
    # one energy level.
    data = numpy.genfromtxt(os.path.join(fpath,"Energies","Fig3_"+fname+".csv"),
                        delimiter=',')
    # Presumably converts W m^-2 to the kW cm^-2 of the axis label -- confirm
    # against the files.
    Intensity = data[0,:]*1e-7
    # Energies relative to the first level, divided by h and scaled by 1e-6
    # (MHz if the files hold energies in joules -- confirm).
    Transition = 1e-6*(data[1:,:]-data[1,:])/h

    try:
        d = numpy.genfromtxt(os.path.join(fpath,"TDM","Fig3_"+fname+".csv"),
                        delimiter=',',dtype=numpy.complex128)
        print("Loaded TDM")
    except IOError:
        # No TDM file: compute the moments from the stored eigenstates.
        states = numpy.load(os.path.join(fpath,"N5_"+fname+"_states.npy"),
                        mmap_mode='r')
        if dz is None:
            dz = dipolez(Nmax,1)
            dz = numpy.kron(dz,numpy.kron(numpy.identity(int(2*I1+1)),
                            numpy.identity(int(2*I2+1))))
        # NOTE(review): this slice holds 97 states while only 96 are plotted
        # below -- confirm the extra state is intended.
        d = numpy.einsum('ix,ij,jkx->kx',
            states[:,0,:],dz,states[:,32:32+96+1,:])
        # NOTE(review): the result is saved under a different name than the
        # one read above, so a later run will not find it -- confirm intended.
        tdm_file = os.path.join(fpath,"Ncalc5"+fname+"_TDMz.csv")
        numpy.savetxt(tdm_file,d,delimiter=',')
        print("Saved dipole moments to:"+tdm_file)

    # Rows 32 to 127 of Transition; row j is paired with row j-32 of d.
    for j in range(32,96+32):
        shifted = Transition[j,:]-panel_offset

        # Plain line underneath, then the same line coloured by strength.
        ax.plot(Intensity,shifted,color=colours['sand'],
                            zorder=1)

        cl = colorline(Intensity,shifted,3*numpy.abs(d[j-32])**2,
            cmap='RbCs_map_tweak_blue',norm=LogNorm(vmin=1e-2,vmax=1.0,clip=True),
            linewidth=2.0,ax=ax)

    now = time.time()
    print(i,"took {:.2f}".format(now-then))

# Every panel shares its y axis with ax1, directly or through another panel,
# so this sets the y range of the whole figure.
ax1.set_ylim(-0.2,0.8)
# Colour bar spanning all three rows in the last grid column.
col_ax = fig.add_subplot(grid[:,4])
# `cl` is the last LineCollection made in the plotting loop; every collection
# there is created with the same colormap and the same norm settings.
fig.colorbar(cl,cax=col_ax)
col_ax.set_title("$z$",color=colours['blue'])
col_ax.set_ylabel("Relative Transition Strength")
# The manual spacing is applied after tight_layout and so replaces the
# spacing tight_layout chose for these settings.
pyplot.tight_layout()
pyplot.subplots_adjust(wspace=0.01,hspace=0.1,top=0.9,bottom=0.15)
# Relative path: the file is written to the current working directory.
pyplot.savefig("fig3.pdf")

pyplot.show()
